# This program trains a Res2Net model on the CIFAR-100 dataset.

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
import argparse
import torch.optim
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data as data
import matplotlib.pyplot as plt
import numpy as np

import res2net as res2net
import SKNet.SKNet as SKNet


mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
num_workers = 2


def cifar100_dataset(args):
    """Build the CIFAR-100 train/test DataLoaders.

    Args:
        args: parsed arguments providing ``data_path`` (dataset root) and
            ``batch_size`` (training batch size).

    Returns:
        (train_loader, test_loader) tuple. The training loader shuffles and
        applies augmentation; the test loader uses a fixed batch size of 100
        with normalization only.
    """
    # Augmentation is applied to the training split only.
    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),  # data augmentation
    ]
    normalization = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    transform_train = transforms.Compose(augmentation + normalization)
    transform_test = transforms.Compose(normalization)

    train_set = torchvision.datasets.CIFAR100(
        root=args.data_path, train=True, download=True, transform=transform_train)
    test_set = torchvision.datasets.CIFAR100(
        root=args.data_path, train=False, download=True, transform=transform_test)

    train_loader = data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    test_loader = data.DataLoader(
        test_set, batch_size=100, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader


def train(net, train_loader, optimizer, criterion, writer, args, epoch, index_num):
    """Run one training epoch and log the epoch's mean loss to TensorBoard.

    Args:
        net: model to train (moved to ``args.device`` by the caller).
        train_loader: DataLoader yielding (inputs, labels) batches.
        optimizer: optimizer stepping ``net``'s parameters.
        criterion: loss function.
        writer: TensorBoard SummaryWriter; receives "loss/train" per epoch.
        args: provides ``device``.
        epoch: current epoch index (used for the progress bar and logging).
        index_num: global step counter at the start of this epoch.

    Returns:
        The updated global step counter (``index_num`` plus the number of
        batches processed). The original never propagated this back, so the
        caller's counter stayed at 0 forever.
    """
    net.train()
    train_tqdm = tqdm(train_loader, desc="Epoch " + str(epoch))
    loss_sum = 0.0
    num_batches = 0
    for inputs, labels in train_tqdm:
        optimizer.zero_grad()
        outputs = net(inputs.to(args.device))
        loss = criterion(outputs, labels.to(args.device))
        loss.backward()
        optimizer.step()
        # .item() detaches the scalar: accumulating the loss *tensor* (as the
        # original did) keeps every batch's autograd graph alive for the whole
        # epoch and steadily leaks (GPU) memory.
        batch_loss = loss.item()
        loss_sum += batch_loss
        index_num += 1
        num_batches += 1
        train_tqdm.set_postfix({"loss": "%.3g" % batch_loss})
    # Mean loss over the epoch's batches. (Original divided by index_num,
    # which equaled the batch count only because the caller always passed 0;
    # max(..., 1) also guards an empty loader against ZeroDivisionError.)
    writer.add_scalar("loss/train", loss_sum / max(num_batches, 1), epoch)
    return index_num


def validate(net, test_loader, criterion, writer, args, epoch, loss_vector, accuracy_vector):
    """Evaluate ``net`` on the test set; log and record loss and accuracy.

    Args:
        net: model to evaluate.
        test_loader: DataLoader yielding (inputs, targets) batches.
        criterion: loss function.
        writer: TensorBoard SummaryWriter; receives "loss/validation" and
            "accuracy/validation" per epoch.
        args: provides ``device``.
        epoch: current epoch index used as the logging step.
        loss_vector: list to append this epoch's mean per-batch loss to.
        accuracy_vector: list to append this epoch's accuracy (percent) to.

    Returns:
        Number of correctly classified samples (0-dim tensor), used by the
        caller for best-model tracking.
    """
    net.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():  # inference only; also safe when caller already wraps
        for inputs, targets in test_loader:
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)
            outputs = net(inputs)
            # Original moved `target` to the device a second time here —
            # redundant; it was transferred just above.
            val_loss += criterion(outputs, targets).item()
            # argmax replaces the deprecated `.data.max(1)[1]` idiom.
            pred = outputs.argmax(dim=1)
            correct += pred.eq(targets).cpu().sum()

    val_loss /= len(test_loader)  # mean loss per batch
    loss_vector.append(val_loss)
    writer.add_scalar("loss/validation", val_loss, epoch)

    accuracy = 100. * correct.to(torch.float32) / len(test_loader.dataset)
    accuracy_vector.append(accuracy)
    writer.add_scalar("accuracy/validation", accuracy, epoch)

    print("***** Eval results *****")
    print('epoch: {}, Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
        epoch, val_loss, correct, len(test_loader.dataset), accuracy))
    return correct


MODEL_NAME = 'cifar_res2net50_26w_4s.pth'
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", default="./data", type=str, help="The input data dir")
    parser.add_argument("--batch_size", default=128, type=int, help="The batch size of training")
    parser.add_argument("--device", default='cuda', type=str, help="The training device")
    parser.add_argument("--learning_rate", default=0.1, type=float, help="learning rate")
    parser.add_argument("--epochs", default=150, type=int, help="Training epoch")
    parser.add_argument("--modeldir", default="./model", type=str)
    args = parser.parse_known_args()[0]

    train_loader, test_loader = cifar100_dataset(args)

    # Ensure the checkpoint directory exists before the first torch.save.
    os.makedirs(args.modeldir, exist_ok=True)
    writer = SummaryWriter(os.path.join(args.modeldir, "tensorboard"))
    net = res2net.res2net50().to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate,
                                weight_decay=5e-4, momentum=0.9)

    lossv, accv = [], []
    index_num = 0
    correct_max = 0
    PATH = os.path.join(args.modeldir, MODEL_NAME)
    for epoch in range(args.epochs):
        train(net, train_loader, optimizer, criterion, writer, args, epoch, index_num)
        with torch.no_grad():
            correct = validate(net, test_loader, criterion, writer, args, epoch, lossv, accv)
        if correct > correct_max:
            # BUG FIX: the original also saved the *latest* weights to PATH
            # unconditionally every epoch, which overwrote this best-model
            # checkpoint; now PATH always holds the best-accuracy model.
            torch.save(net.state_dict(), PATH)
            correct_max = correct
        elif optimizer.param_groups[0]['lr'] > 0.002:
            # Halve the learning rate whenever validation accuracy stalls,
            # down to a floor of 0.002.
            optimizer.param_groups[0]['lr'] /= 2

    plt.figure(figsize=(5, 3))
    plt.plot(np.arange(1, args.epochs + 1), lossv)
    plt.title('validation loss')
    plt.savefig(os.path.join(args.modeldir, 'validation_loss'))

    plt.figure(figsize=(5, 3))
    plt.plot(np.arange(1, args.epochs + 1), accv)
    plt.title('validation accuracy')
    plt.savefig(os.path.join(args.modeldir, 'validation_accuracy'))
